import os
import numpy as np
import torch
import sys
import multiprocessing
from config import cfg

if __name__ == '__main__':
    #------------prepare enviroment------------
    seed = cfg.SEED
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    gpus = cfg.GPU_ID
    if len(gpus)==1:
        torch.cuda.set_device(gpus[0])

    torch.backends.cudnn.benchmark = True


    #------------prepare data loader------------
    data_mode = cfg.DATASET
    if data_mode is 'SHHA':
        from datasets.SHHA.loading_data import loading_data
        from datasets.SHHA.setting import cfg_data
    elif data_mode is 'SHHB':
        from datasets.SHHB.loading_data import loading_data
        from datasets.SHHB.setting import cfg_data
    elif data_mode is 'QNRF':
        from datasets.QNRF.loading_data import loading_data
        from datasets.QNRF.setting import cfg_data
    elif data_mode is 'UCF50':
        from datasets.UCF50.loading_data import loading_data
        from datasets.UCF50.setting import cfg_data
    elif data_mode is 'WE':
        from datasets.WE.loading_data import loading_data
        from datasets.WE.setting import cfg_data
    elif data_mode is 'GCC':
        from datasets.GCC.loading_data import loading_data
        from datasets.GCC.setting import cfg_data
    elif data_mode is 'Mall':
        from datasets.Mall.loading_data import loading_data
        from datasets.Mall.setting import cfg_data
    elif data_mode is 'UCSD':
        from datasets.UCSD.loading_data import loading_data
        from datasets.UCSD.setting import cfg_data


    #------------Prepare Trainer------------
    net = cfg.NET
    if net in ['MCNN', 'AlexNet', 'VGG', 'VGG_DECODER', 'Res50', 'Res101',
               'CSRNet','Res101_SFCN', 'SCAR', 'DANet',
               'Stackpool_Base', 'Stackpool_Deep', 'Stackpool_Wide']:
        from trainer import Trainer
    elif net in ['SANet']:
        from trainer_for_M2TCC import Trainer # double losses but signle output
    elif net in ['CMTL']:
        from trainer_for_CMTL import Trainer # double losses and double outputs
    # elif net in ['PCCNet']:
    #     from trainer_for_M3T3OCC import Trainer

    #------------Start Training------------
    if sys.platform.startswith('win'):
        # Hack for multiprocessing.freeze_support() to work from a
        # setuptools-generated entry point.
        multiprocessing.freeze_support()

    pwd = os.path.split(os.path.realpath(__file__))[0]
    cc_trainer = Trainer(loading_data,cfg_data,pwd)
    cc_trainer.forward()
